tutorialのLSTMの例をPyTorch Lightningで書き換える試み
PyTorch Lightningで書き換え
LightningDataModuleを作るのが少し大変
prepare_sequenceすると長さの揃わないtorch.Tensor
code:python
training_data = [
(
"The dog ate the apple".lower().split(),
"DET", "NN", "V", "DET", "NN",
),
("Everybody read that book".lower().split(), "NN", "V", "DET", "NN"),
]
# 辞書 word_to_ix, tag_to_ix を使って、整数に変換
>> sentence_tensors
[tensor(0, 1, 2, 0, 3), tensor(4, 5, 6, 7)]
>> tags_tensors
[tensor(0, 1, 2, 0, 1), tensor(1, 2, 0, 1)]
長さの揃わないTensorのリストからまとめたTensorが作れない
code:python
>> # torch.tensor(sentence_tensors0) や torch.tensor(sentence_tensors1) (1要素の場合)はOK
>> torch.tensor(sentence_tensors)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
TypeError: only integer tensors of a single element can be converted to an index
>> # tensorのlistの代わりに、listのlistから作れるか試した
>> # 参考: https://pytorch.org/docs/stable/tensors.html#initializing-and-basic-operations
>> torch.tensor([sentence_tensors0.tolist(), sentence_tensors1.tolist()])
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
ValueError: expected sequence of length 5 at dim 1 (got 4)
解決策:IterableDatasetにまとめる
https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset
__iter__を実装する
code:python
>> class MyIterableDataset(IterableDataset):
... def __init__(self):
... super(MyIterableDataset, self).__init__()
... def __iter__(self):
... return zip(sentence_tensors, tags_tensors)
...
>> ds = MyIterableDataset()
>> list(ds)
[(tensor(0, 1, 2, 0, 3), tensor(0, 1, 2, 0, 1)), (tensor(4, 5, 6, 7), tensor(1, 2, 0, 1))]
>> list(DataLoader(ds))
tensor([[0, 1, 2, 0, 3), tensor(0, 1, 2, 0, 1)], [tensor(4, 5, 6, 7), tensor(1, 2, 0, 1)]]
DataLoaderではTensorに次数が追加されている
https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
デフォルトがbatch_size=1のため
Tensorのsizeが異なるため、batch_size=2は指定できない(list(DataLoader(ds, batch_size=2)))
RuntimeError: stack expects each tensor to be equal size, but got [5] at entry 0 and [4] at entry 1
code:python
>> list(ds)00.size()
torch.Size(5)
>> list(DataLoader(ds))00.size()
torch.Size(1, 5)
上の挙動により、LSTMのネットワークではword_embeddingslayerに(1, 5), (1, 4)というsizeで入力されてしまう
viewメソッド呼び出しで、Tensorの次数追加を無効にする必要あり
LSTM layerの入力の第2引数が、ミニバッチを考えていないので1(PyTorch tutorial: LSTM)
💡ミニバッチを考えるなら、長さを揃えるためのpadding
pad_sequenceを使ってtutorialのLSTMの例を書き換える試みへ続く